Keep track of DeathMonitor cookies This change keeps track of the objects that the cookies points to so the serviceDied callback knows when it can use the cookie. Test: atest neuralnetworks_utils_hal_aidl_test Tets: atest NeuralNetworksTest_static Bug: 319210610 (cherry picked from https://googleplex-android-review.googlesource.com/q/commit:def7a3cf59fa17ba7faa9af14a24f4161bc276bd) (cherry picked from https://googleplex-android-review.googlesource.com/q/commit:49859a3b5542270363efe42a56b9145142bbfa60) Merged-In: I418cbc6baa19aa702d9fd2e7d8096fe1a02b7794 Change-Id: I418cbc6baa19aa702d9fd2e7d8096fe1a02b7794 
diff --git a/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/ProtectCallback.h b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/ProtectCallback.h index 92ed1cd..9a7fe5e 100644 --- a/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/ProtectCallback.h +++ b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/ProtectCallback.h 
@@ -56,6 +56,8 @@  // Thread safe class  class DeathMonitor final {  public: + explicit DeathMonitor(uintptr_t cookieKey) : kCookieKey(cookieKey) {} +  static void serviceDied(void* cookie);  void serviceDied();  // Precondition: `killable` must be non-null. @@ -63,9 +65,18 @@  // Precondition: `killable` must be non-null.  void remove(IProtectedCallback* killable) const;   + uintptr_t getCookieKey() const { return kCookieKey; } + + ~DeathMonitor(); + DeathMonitor(const DeathMonitor&) = delete; + DeathMonitor(DeathMonitor&&) noexcept = delete; + DeathMonitor& operator=(const DeathMonitor&) = delete; + DeathMonitor& operator=(DeathMonitor&&) noexcept = delete; +  private:  mutable std::mutex mMutex;  mutable std::vector<IProtectedCallback*> mObjects GUARDED_BY(mMutex); + const uintptr_t kCookieKey;  };    class DeathHandler final { 
diff --git a/neuralnetworks/aidl/utils/src/ProtectCallback.cpp b/neuralnetworks/aidl/utils/src/ProtectCallback.cpp index 54a673c..4a7ac08 100644 --- a/neuralnetworks/aidl/utils/src/ProtectCallback.cpp +++ b/neuralnetworks/aidl/utils/src/ProtectCallback.cpp 
@@ -25,6 +25,7 @@    #include <algorithm>  #include <functional> +#include <map>  #include <memory>  #include <mutex>  #include <vector> @@ -33,6 +34,16 @@    namespace aidl::android::hardware::neuralnetworks::utils {   +namespace { + +// Only dereference the cookie if it's valid (if it's in this set) +// Only used with ndk +std::mutex sCookiesMutex; +uintptr_t sCookieKeyCounter GUARDED_BY(sCookiesMutex) = 0; +std::map<uintptr_t, std::weak_ptr<DeathMonitor>> sCookies GUARDED_BY(sCookiesMutex); + +} // namespace +  void DeathMonitor::serviceDied() {  std::lock_guard guard(mMutex);  std::for_each(mObjects.begin(), mObjects.end(), @@ -40,8 +51,24 @@  }    void DeathMonitor::serviceDied(void* cookie) { - auto deathMonitor = static_cast<DeathMonitor*>(cookie); - deathMonitor->serviceDied(); + std::shared_ptr<DeathMonitor> monitor; + { + std::lock_guard<std::mutex> guard(sCookiesMutex); + if (auto it = sCookies.find(reinterpret_cast<uintptr_t>(cookie)); it != sCookies.end()) { + monitor = it->second.lock(); + sCookies.erase(it); + } else { + LOG(INFO) + << "Service died, but cookie is no longer valid so there is nothing to notify."; + return; + } + } + if (monitor) { + LOG(INFO) << "Notifying DeathMonitor from serviceDied."; + monitor->serviceDied(); + } else { + LOG(INFO) << "Tried to notify DeathMonitor from serviceDied but could not promote."; + }  }    void DeathMonitor::add(IProtectedCallback* killable) const { @@ -57,12 +84,25 @@  mObjects.erase(removedIter);  }   +DeathMonitor::~DeathMonitor() { + // lock must be taken so object is not used in OnBinderDied" + std::lock_guard<std::mutex> guard(sCookiesMutex); + sCookies.erase(kCookieKey); +} +  nn::GeneralResult<DeathHandler> DeathHandler::create(std::shared_ptr<ndk::ICInterface> object) {  if (object == nullptr) {  return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)  << "utils::DeathHandler::create must have non-null object";  } - auto deathMonitor = std::make_shared<DeathMonitor>(); + + std::shared_ptr<DeathMonitor> deathMonitor; + { + std::lock_guard<std::mutex> guard(sCookiesMutex); + deathMonitor = std::make_shared<DeathMonitor>(sCookieKeyCounter++); + sCookies[deathMonitor->getCookieKey()] = deathMonitor; + } +  auto deathRecipient = ndk::ScopedAIBinder_DeathRecipient(  AIBinder_DeathRecipient_new(DeathMonitor::serviceDied));   @@ -70,8 +110,9 @@  // STATUS_INVALID_OPERATION. We ignore this case because we only use local binders in tests  // where this is not an error.  if (object->isRemote()) { - const auto ret = ndk::ScopedAStatus::fromStatus(AIBinder_linkToDeath( - object->asBinder().get(), deathRecipient.get(), deathMonitor.get())); + const auto ret = ndk::ScopedAStatus::fromStatus( + AIBinder_linkToDeath(object->asBinder().get(), deathRecipient.get(), + reinterpret_cast<void*>(deathMonitor->getCookieKey())));  HANDLE_ASTATUS(ret) << "AIBinder_linkToDeath failed";  }   @@ -91,8 +132,9 @@    DeathHandler::~DeathHandler() {  if (kObject != nullptr && kDeathRecipient.get() != nullptr && kDeathMonitor != nullptr) { - const auto ret = ndk::ScopedAStatus::fromStatus(AIBinder_unlinkToDeath( - kObject->asBinder().get(), kDeathRecipient.get(), kDeathMonitor.get())); + const auto ret = ndk::ScopedAStatus::fromStatus( + AIBinder_unlinkToDeath(kObject->asBinder().get(), kDeathRecipient.get(), + reinterpret_cast<void*>(kDeathMonitor->getCookieKey())));  const auto maybeSuccess = handleTransportError(ret);  if (!maybeSuccess.ok()) {  LOG(ERROR) << maybeSuccess.error().message; 
diff --git a/neuralnetworks/aidl/utils/test/DeviceTest.cpp b/neuralnetworks/aidl/utils/test/DeviceTest.cpp index 73727b3..ffd3b8e 100644 --- a/neuralnetworks/aidl/utils/test/DeviceTest.cpp +++ b/neuralnetworks/aidl/utils/test/DeviceTest.cpp 
@@ -697,7 +697,8 @@  const auto mockDevice = createMockDevice();  const auto device = Device::create(kName, mockDevice, kVersion).value();  const auto ret = [&device]() { - DeathMonitor::serviceDied(device->getDeathMonitor()); + DeathMonitor::serviceDied( + reinterpret_cast<void*>(device->getDeathMonitor()->getCookieKey()));  return ndk::ScopedAStatus::ok();  };  EXPECT_CALL(*mockDevice, prepareModel(_, _, _, _, _, _, _, _)) @@ -846,7 +847,8 @@  const auto mockDevice = createMockDevice();  const auto device = Device::create(kName, mockDevice, kVersion).value();  const auto ret = [&device]() { - DeathMonitor::serviceDied(device->getDeathMonitor()); + DeathMonitor::serviceDied( + reinterpret_cast<void*>(device->getDeathMonitor()->getCookieKey()));  return ndk::ScopedAStatus::ok();  };  EXPECT_CALL(*mockDevice, prepareModelWithConfig(_, _, _)) @@ -970,7 +972,8 @@  const auto mockDevice = createMockDevice();  const auto device = Device::create(kName, mockDevice, kVersion).value();  const auto ret = [&device]() { - DeathMonitor::serviceDied(device->getDeathMonitor()); + DeathMonitor::serviceDied( + reinterpret_cast<void*>(device->getDeathMonitor()->getCookieKey()));  return ndk::ScopedAStatus::ok();  };  EXPECT_CALL(*mockDevice, prepareModelFromCache(_, _, _, _, _))